#include "OWLQN.h"
#include "TerminationCriterion.h"

#include <vector>
#include <deque>
#include <cmath>
#include <iostream>
#include <iomanip>
#include <sstream>

using namespace std;

double OptimizerState::dotProduct(const DblVec& a, const DblVec& b){
	double result = 0;
//	cout << a[0] << " " << b[0] << endl;
	for(size_t i=0; i < a.size(); i++){
		
		result += a[i] * b[i];
	}
	return result;
}

void OptimizerState::addMult(DblVec& a, const DblVec& b, double c){
	for (size_t i = 0; i < a.size(); i++){
		a[i] += b[i] * c;
	}
}

void OptimizerState::add(DblVec& a, const DblVec& b){
	for (size_t i = 0; i <a.size(); i++){
		a[i] += b[i];
	}
}

void OptimizerState::addMultInto(DblVec& a, const DblVec& b, const DblVec& c, double d){
	for (size_t i = 0; i < a.size(); i++){
		a[i] = b[i] + c[i] * d;
	}
} 

void OptimizerState::scale(DblVec& a, double b){
	for (size_t i = 0; i < a.size(); i++){
		a[i] *= b;
	}
} 

void OptimizerState::scaleInto(DblVec& a, const DblVec& b, double c){
	for (size_t i= 0; i < a.size(); i++){
		a[i] = b[i] * c;
	}
}

void OptimizerState::MapDirByInverseHessian() {
	int count = (int)sList.size();
	if (count != 0){
		for (int i = count - 1; i >= 0; i--){
			alphas[i] = -dotProduct(*sList[i], dir) / roList[i];
			addMult(dir, *yList[i], alphas[i]);
		}
		const DblVec& lastY = *yList[count-1];
		double yDotY = dotProduct(lastY, lastY);
		double scalar = roList[count - 1] / yDotY;
		scale(dir, scalar);
		
		for (int i = 0; i < count; i++){
			double beta = dotProduct(*yList[i], dir) / roList[i];
			addMult(dir, *sList[i], -alphas[i] - beta);
		}
	}
}

void OptimizerState::MakeSteepestDescDir() {
	if (l1weight == 0){
		scaleInto(dir, grad, -1);
	}
	else {
		for (size_t i = 0; i < dim; i++){
			if (x[i] < 0){
				dir[i] = -grad[i] + l1weight;
			}
			else if (x[i] > 0){
				dir[i] = -grad[i] - l1weight;
			}
			else {
				if (grad[i] < -l1weight) {
					dir[i] = -grad[i] - l1weight;
				}
				else if (grad[i] > l1weight) {
					dir[i] = -grad[i] + l1weight;
				}
				else {
					dir[i] = 0;
				}
			}
		}
	}
	
	steepestDescDir = dir;
}


void OptimizerState::FixDirSigns(){
	if (l1weight > 0) {
		for (size_t i = 0; i < dim; i++){
			if ( dir[i] * steepestDescDir[i] <= 0){
				dir[i] = 0;
			}
		}
	}
}


void OptimizerState::UpdateDir() {
	MakeSteepestDescDir();
	MapDirByInverseHessian();
	FixDirSigns();
#ifdef _DEBUG_
	TestDirDeriv();
#endif
}


void OptimizerState::TestDirDeriv() {
	//cout << dotProduct(dir, dir) << endl;
	double dirNorm = sqrt(dotProduct(dir, dir));
	if(dirNorm == 0){
		cout << "Zero Direction\n" << endl;
		return;
	}
	double eps = 1.05e-8 / dirNorm;
	GetNextPoint(eps);
	double val2 = EvalL1();
	double numDeriv = (val2 - value) / eps;
	double deriv = DirDeriv();
	cout << "  Grad check: " << numDeriv << " vs. " << deriv << "  ";
}

double OptimizerState::DirDeriv() const {
	if (l1weight == 0){
		return dotProduct(dir, grad);
	}
	else {
		double val = 0.0;
		for (size_t i = 0; i < dim; i++){
			if (dir[i] != 0){
				if(x[i] < 0){
					val += dir[i] * (grad[i] - l1weight);
				}else if(x[i] > 0){
					val += dir[i] * (grad[i] + l1weight);
				}else if(dir[i] < 0){
					val += dir[i] * (grad[i] - l1weight);
				}else if(dir[i] > 0){
					val += dir[i] * (grad[i] + l1weight);
				}
			}
		}
		return val;
	}
	
}

//Constrained line search
void OptimizerState::GetNextPoint(double alpha){
	addMultInto(newX, x, dir, alpha);
	if(l1weight > 0){
		for (size_t i = 0; i < dim; i++){
			if (x[i] * newX[i] < 0.0){
				newX[i] = 0.0;
			}
		}
	}
}

double OptimizerState::EvalL1() {
	func.handler(0, 1);
	double val = func.Eval(newX, newGrad);
	if (l1weight > 0){
		for (size_t i = 0; i < dim; i++){
			val += fabs(newX[i]) * l1weight;
		}
	}
	return val;
}

void OptimizerState::BackTrackingLineSearch() {
	double origDirDeriv = DirDeriv();
	//if a non-descent direction is chosen, the line search will break anyway, so throw here
	//The most likely reason for this is a bug in your function's gradient computation
	if (origDirDeriv > 0){
		cerr << "L-BFGS chose a non-descent direction: check your gradient!" << endl;
		exit(1);
	}
	double alpha = 1.0;
	double backoff = 0.5;
	if (iter == 1){
		double normDir = sqrt(dotProduct(dir, dir));
		alpha = (1 / normDir);
		backoff = 0.1;
	}
	
	const double c1 = 1e-4;
	double oldValue = value;

	while (true){
		GetNextPoint(alpha);
		value = EvalL1();
		if(value <= oldValue + c1 * origDirDeriv * alpha) break;
		
		cout << "." << flush;
		alpha *= backoff;
	}
	cout << endl;
} 

void OptimizerState::Shift() {
	DblVec *nextS = NULL, *nextY = NULL;
	
	int listSize = (int)sList.size();
	
	if (listSize < m) {
		try{
			nextS = new vector<double>(dim);
			nextY = new vector<double>(dim);
		} catch (bad_alloc){
			m = listSize;
			if (nextS != NULL){
				delete nextS;
				nextS = NULL;
			}
		}
	}
	if (nextS == NULL){
		nextS = sList.front();
		sList.pop_front();
		nextY = yList.front();
		yList.pop_front();
		roList.pop_front();
	}
	addMultInto(*nextS, newX, x, -1);
	addMultInto(*nextY, newGrad, grad, -1);
	double ro = dotProduct(*nextS, *nextY);
	if (ro != 0){
		sList.push_back(nextS);
		yList.push_back(nextY);
		roList.push_back(ro);
	}
		
	x.swap(newX);
	grad.swap(newGrad);
	
	iter++;
}


double OWLQN::Minimize(DifferentiableFunction& function, const DblVec& initial, DblVec& minimum, double l1weight,double tol,int m, int iter) const {
	OptimizerState state(function, initial, m, l1weight);
	cout << setprecision(4) << scientific << right;
	cout << endl << "Optimizing function of " << state.dim << " variables with OWL_QN parameters: " << endl;
	cout << "	l1 regularization weitht: " << l1weight << "." << endl;
	cout << "	L-BFGS memory parameter (m): " << m << endl;
	cout << "	Convergence tolerance: " << tol << endl;
	cout << endl;
	cout << "Iter n: new_value (conv_crit) line_search" << endl << flush;
	cout << "Iter 0:   " << setw(10) << state.value << " (**********) " << flush;
	
	
	
	while(true) {
		state.UpdateDir();
		if(state.DirDeriv() == 0) break;
		state.BackTrackingLineSearch();

		ostringstream str;
		double termCritVal = termCrit->GetValue(state, str);
		cout << "Iter " << setw(4) << state.iter << ": "  << setw(10) << state.value;
		cout << str.str() << flush;
		
		if (termCritVal < tol) break;
		
		state.Shift(); 
		if(state.iter >= 500) break;
	}
	
	cout << endl;
	minimum = state.newX;
	cout << "Iter " << setw(4) << state.iter << ": "  << setw(10) << state.value << endl;
	return state.value;
}


